Import all the Dependencies¶

In [41]:
import tensorflow as tf
from tensorflow.keras import models, layers
import matplotlib.pyplot as plt
from IPython.display import HTML

Set all the Constants¶

In [42]:
BATCH_SIZE = 32
IMAGE_SIZE = 256
CHANNELS=3
EPOCHS=100

Import data into tensorflow dataset object¶

In [43]:
dataset = tf.keras.preprocessing.image_dataset_from_directory(
    "PlantVillage copy",
    seed=123,
    shuffle=True,
    image_size=(IMAGE_SIZE,IMAGE_SIZE),
    batch_size=BATCH_SIZE
)
Found 2152 files belonging to 3 classes.
In [44]:
class_names = dataset.class_names
class_names
Out[44]:
['Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy']
In [45]:
for image_batch, labels_batch in dataset.take(1):
    print(image_batch.shape)
    print(labels_batch.numpy())
(32, 256, 256, 3)
[1 1 1 0 0 0 0 0 1 1 1 1 0 1 0 1 1 1 0 1 0 1 0 0 1 0 0 1 1 2 0 0]

As you can see above, each element in the dataset is a tuple. First element is a batch of 32 elements of images. Second element is a batch of 32 elements of class labels

Visualize some of the images from our dataset¶

In [46]:
plt.figure(figsize=(10, 10))
for image_batch, labels_batch in dataset.take(1):
    for i in range(12):
        ax = plt.subplot(3, 4, i + 1)
        plt.imshow(image_batch[i].numpy().astype("uint8"))
        plt.title(class_names[labels_batch[i]])
        plt.axis("off")
In [47]:
def get_dataset_partitions_tf(ds, train_split=0.8, val_split=0.1, test_split=0.1, shuffle=True, shuffle_size=10000):
    assert (train_split + test_split + val_split) == 1
    
    ds_size = len(ds)
    
    if shuffle:
        ds = ds.shuffle(shuffle_size, seed=12)
    
    train_size = int(train_split * ds_size)
    val_size = int(val_split * ds_size)
    
    train_ds = ds.take(train_size)    
    val_ds = ds.skip(train_size).take(val_size)
    test_ds = ds.skip(train_size).skip(val_size)
    
    return train_ds, val_ds, test_ds
In [48]:
train_ds, val_ds, test_ds = get_dataset_partitions_tf(dataset)

Cache, Shuffle, and Prefetch the Dataset¶

In [49]:
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=tf.data.AUTOTUNE)
val_ds = val_ds.cache().shuffle(1000).prefetch(buffer_size=tf.data.AUTOTUNE)
test_ds = test_ds.cache().shuffle(1000).prefetch(buffer_size=tf.data.AUTOTUNE)

Building the Model¶

In [50]:
resize_and_rescale = tf.keras.Sequential([
  layers.experimental.preprocessing.Resizing(IMAGE_SIZE, IMAGE_SIZE),
  layers.experimental.preprocessing.Rescaling(1./255),
])

Data Augmentation¶

In [51]:
data_augmentation = tf.keras.Sequential([
  layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
  layers.experimental.preprocessing.RandomRotation(0.2),
])

Applying Data Augmentation to Train Dataset¶

In [52]:
train_ds = train_ds.map(
    lambda x, y: (data_augmentation(x, training=True), y)
).prefetch(buffer_size=tf.data.AUTOTUNE)

Model Architecture¶

We use a CNN coupled with a Softmax activation in the output layer. We also add the initial layers for resizing, normalization and Data Augmentation.

In [53]:
input_shape = (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, CHANNELS)
n_classes = 3

model = models.Sequential([
    resize_and_rescale,
    layers.Conv2D(32, kernel_size = (3,3), activation='relu', input_shape=input_shape),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64,  kernel_size = (3,3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64,  kernel_size = (3,3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(n_classes, activation='softmax'),
])

model.build(input_shape=input_shape)
In [54]:
model.summary()
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 sequential_3 (Sequential)   (32, 256, 256, 3)         0         
                                                                 
 conv2d_6 (Conv2D)           (32, 254, 254, 32)        896       
                                                                 
 max_pooling2d_6 (MaxPoolin  (32, 127, 127, 32)        0         
 g2D)                                                            
                                                                 
 conv2d_7 (Conv2D)           (32, 125, 125, 64)        18496     
                                                                 
 max_pooling2d_7 (MaxPoolin  (32, 62, 62, 64)          0         
 g2D)                                                            
                                                                 
 conv2d_8 (Conv2D)           (32, 60, 60, 64)          36928     
                                                                 
 max_pooling2d_8 (MaxPoolin  (32, 30, 30, 64)          0         
 g2D)                                                            
                                                                 
 conv2d_9 (Conv2D)           (32, 28, 28, 64)          36928     
                                                                 
 max_pooling2d_9 (MaxPoolin  (32, 14, 14, 64)          0         
 g2D)                                                            
                                                                 
 conv2d_10 (Conv2D)          (32, 12, 12, 64)          36928     
                                                                 
 max_pooling2d_10 (MaxPooli  (32, 6, 6, 64)            0         
 ng2D)                                                           
                                                                 
 conv2d_11 (Conv2D)          (32, 4, 4, 64)            36928     
                                                                 
 max_pooling2d_11 (MaxPooli  (32, 2, 2, 64)            0         
 ng2D)                                                           
                                                                 
 flatten_1 (Flatten)         (32, 256)                 0         
                                                                 
 dense_2 (Dense)             (32, 64)                  16448     
                                                                 
 dense_3 (Dense)             (32, 3)                   195       
                                                                 
=================================================================
Total params: 183747 (717.76 KB)
Trainable params: 183747 (717.76 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

Compiling the Model¶

We use adam Optimizer, SparseCategoricalCrossentropy for losses, accuracy as a metric

In [55]:
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy']
)
In [56]:
history = model.fit(
    train_ds,
    batch_size=BATCH_SIZE,
    validation_data=val_ds,
    verbose=1,
    epochs=100,
)
Epoch 1/100
54/54 [==============================] - 25s 442ms/step - loss: 0.8767 - accuracy: 0.5318 - val_loss: 0.8879 - val_accuracy: 0.5156
Epoch 2/100
54/54 [==============================] - 22s 414ms/step - loss: 0.6209 - accuracy: 0.7274 - val_loss: 0.6698 - val_accuracy: 0.6875
Epoch 3/100
54/54 [==============================] - 23s 416ms/step - loss: 0.3941 - accuracy: 0.8339 - val_loss: 0.5220 - val_accuracy: 0.7812
Epoch 4/100
54/54 [==============================] - 22s 414ms/step - loss: 0.2997 - accuracy: 0.8796 - val_loss: 0.2591 - val_accuracy: 0.8698
Epoch 5/100
54/54 [==============================] - 22s 415ms/step - loss: 0.2731 - accuracy: 0.8970 - val_loss: 0.2757 - val_accuracy: 0.8698
Epoch 6/100
54/54 [==============================] - 22s 414ms/step - loss: 0.2317 - accuracy: 0.9080 - val_loss: 0.2344 - val_accuracy: 0.9010
Epoch 7/100
54/54 [==============================] - 22s 413ms/step - loss: 0.2171 - accuracy: 0.9161 - val_loss: 0.2930 - val_accuracy: 0.8802
Epoch 8/100
54/54 [==============================] - 22s 413ms/step - loss: 0.2399 - accuracy: 0.9068 - val_loss: 0.1825 - val_accuracy: 0.9375
Epoch 9/100
54/54 [==============================] - 22s 415ms/step - loss: 0.1648 - accuracy: 0.9340 - val_loss: 0.1831 - val_accuracy: 0.9375
Epoch 10/100
54/54 [==============================] - 22s 414ms/step - loss: 0.1634 - accuracy: 0.9387 - val_loss: 0.1821 - val_accuracy: 0.9219
Epoch 11/100
54/54 [==============================] - 22s 412ms/step - loss: 0.1598 - accuracy: 0.9398 - val_loss: 0.1714 - val_accuracy: 0.9219
Epoch 12/100
54/54 [==============================] - 22s 414ms/step - loss: 0.1548 - accuracy: 0.9473 - val_loss: 0.1524 - val_accuracy: 0.9271
Epoch 13/100
54/54 [==============================] - 22s 412ms/step - loss: 0.1380 - accuracy: 0.9560 - val_loss: 0.1000 - val_accuracy: 0.9583
Epoch 14/100
54/54 [==============================] - 22s 412ms/step - loss: 0.1397 - accuracy: 0.9485 - val_loss: 0.0516 - val_accuracy: 0.9792
Epoch 15/100
54/54 [==============================] - 22s 412ms/step - loss: 0.1057 - accuracy: 0.9595 - val_loss: 0.1098 - val_accuracy: 0.9531
Epoch 16/100
54/54 [==============================] - 22s 413ms/step - loss: 0.0843 - accuracy: 0.9693 - val_loss: 0.0876 - val_accuracy: 0.9635
Epoch 17/100
54/54 [==============================] - 22s 413ms/step - loss: 0.1111 - accuracy: 0.9525 - val_loss: 0.1007 - val_accuracy: 0.9583
Epoch 18/100
54/54 [==============================] - 22s 412ms/step - loss: 0.0823 - accuracy: 0.9664 - val_loss: 0.1723 - val_accuracy: 0.9323
Epoch 19/100
54/54 [==============================] - 22s 412ms/step - loss: 0.0743 - accuracy: 0.9722 - val_loss: 0.0745 - val_accuracy: 0.9583
Epoch 20/100
54/54 [==============================] - 22s 413ms/step - loss: 0.1181 - accuracy: 0.9485 - val_loss: 0.0810 - val_accuracy: 0.9792
Epoch 21/100
54/54 [==============================] - 22s 413ms/step - loss: 0.0670 - accuracy: 0.9797 - val_loss: 0.3638 - val_accuracy: 0.8698
Epoch 22/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0760 - accuracy: 0.9705 - val_loss: 0.4486 - val_accuracy: 0.8594
Epoch 23/100
54/54 [==============================] - 22s 413ms/step - loss: 0.0510 - accuracy: 0.9826 - val_loss: 0.4307 - val_accuracy: 0.8750
Epoch 24/100
54/54 [==============================] - 22s 412ms/step - loss: 0.0600 - accuracy: 0.9757 - val_loss: 0.1793 - val_accuracy: 0.9323
Epoch 25/100
54/54 [==============================] - 22s 411ms/step - loss: 0.1002 - accuracy: 0.9635 - val_loss: 0.0318 - val_accuracy: 0.9948
Epoch 26/100
54/54 [==============================] - 22s 411ms/step - loss: 0.0680 - accuracy: 0.9803 - val_loss: 0.0531 - val_accuracy: 0.9896
Epoch 27/100
54/54 [==============================] - 22s 412ms/step - loss: 0.0900 - accuracy: 0.9641 - val_loss: 0.1522 - val_accuracy: 0.9479
Epoch 28/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0866 - accuracy: 0.9740 - val_loss: 0.1593 - val_accuracy: 0.9323
Epoch 29/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0524 - accuracy: 0.9792 - val_loss: 0.1855 - val_accuracy: 0.9271
Epoch 30/100
54/54 [==============================] - 22s 414ms/step - loss: 0.0746 - accuracy: 0.9757 - val_loss: 0.1110 - val_accuracy: 0.9635
Epoch 31/100
54/54 [==============================] - 22s 411ms/step - loss: 0.0995 - accuracy: 0.9589 - val_loss: 0.0992 - val_accuracy: 0.9531
Epoch 32/100
54/54 [==============================] - 22s 411ms/step - loss: 0.0445 - accuracy: 0.9803 - val_loss: 0.2461 - val_accuracy: 0.9010
Epoch 33/100
54/54 [==============================] - 22s 412ms/step - loss: 0.0346 - accuracy: 0.9878 - val_loss: 0.1447 - val_accuracy: 0.9375
Epoch 34/100
54/54 [==============================] - 22s 411ms/step - loss: 0.0437 - accuracy: 0.9826 - val_loss: 0.0952 - val_accuracy: 0.9635
Epoch 35/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0341 - accuracy: 0.9884 - val_loss: 0.3204 - val_accuracy: 0.8958
Epoch 36/100
54/54 [==============================] - 22s 411ms/step - loss: 0.0137 - accuracy: 0.9959 - val_loss: 0.0572 - val_accuracy: 0.9844
Epoch 37/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0702 - accuracy: 0.9740 - val_loss: 0.1208 - val_accuracy: 0.9531
Epoch 38/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0365 - accuracy: 0.9867 - val_loss: 0.5040 - val_accuracy: 0.8854
Epoch 39/100
54/54 [==============================] - 22s 412ms/step - loss: 0.0321 - accuracy: 0.9896 - val_loss: 0.1342 - val_accuracy: 0.9531
Epoch 40/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0279 - accuracy: 0.9913 - val_loss: 0.0303 - val_accuracy: 0.9896
Epoch 41/100
54/54 [==============================] - 22s 414ms/step - loss: 0.0207 - accuracy: 0.9931 - val_loss: 0.0610 - val_accuracy: 0.9688
Epoch 42/100
54/54 [==============================] - 22s 411ms/step - loss: 0.0460 - accuracy: 0.9838 - val_loss: 0.1811 - val_accuracy: 0.9531
Epoch 43/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0231 - accuracy: 0.9913 - val_loss: 0.2575 - val_accuracy: 0.9271
Epoch 44/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0277 - accuracy: 0.9907 - val_loss: 0.0180 - val_accuracy: 0.9896
Epoch 45/100
54/54 [==============================] - 22s 411ms/step - loss: 0.0254 - accuracy: 0.9902 - val_loss: 0.2487 - val_accuracy: 0.9271
Epoch 46/100
54/54 [==============================] - 22s 411ms/step - loss: 0.0141 - accuracy: 0.9948 - val_loss: 0.0307 - val_accuracy: 0.9896
Epoch 47/100
54/54 [==============================] - 22s 412ms/step - loss: 0.0655 - accuracy: 0.9815 - val_loss: 0.0445 - val_accuracy: 0.9844
Epoch 48/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0379 - accuracy: 0.9850 - val_loss: 0.0474 - val_accuracy: 0.9844
Epoch 49/100
54/54 [==============================] - 22s 412ms/step - loss: 0.0229 - accuracy: 0.9907 - val_loss: 0.2333 - val_accuracy: 0.9271
Epoch 50/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0428 - accuracy: 0.9861 - val_loss: 0.1289 - val_accuracy: 0.9635
Epoch 51/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0356 - accuracy: 0.9873 - val_loss: 0.4064 - val_accuracy: 0.8854
Epoch 52/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0614 - accuracy: 0.9774 - val_loss: 0.0456 - val_accuracy: 0.9792
Epoch 53/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0192 - accuracy: 0.9936 - val_loss: 0.0459 - val_accuracy: 0.9844
Epoch 54/100
54/54 [==============================] - 22s 411ms/step - loss: 0.0133 - accuracy: 0.9965 - val_loss: 0.0585 - val_accuracy: 0.9740
Epoch 55/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0153 - accuracy: 0.9931 - val_loss: 0.0731 - val_accuracy: 0.9635
Epoch 56/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0307 - accuracy: 0.9878 - val_loss: 0.0166 - val_accuracy: 0.9948
Epoch 57/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0355 - accuracy: 0.9867 - val_loss: 0.0895 - val_accuracy: 0.9688
Epoch 58/100
54/54 [==============================] - 22s 411ms/step - loss: 0.0174 - accuracy: 0.9931 - val_loss: 0.2074 - val_accuracy: 0.9479
Epoch 59/100
54/54 [==============================] - 22s 411ms/step - loss: 0.0240 - accuracy: 0.9913 - val_loss: 0.1063 - val_accuracy: 0.9583
Epoch 60/100
54/54 [==============================] - 22s 412ms/step - loss: 0.0367 - accuracy: 0.9861 - val_loss: 0.3355 - val_accuracy: 0.9062
Epoch 61/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0296 - accuracy: 0.9896 - val_loss: 0.2719 - val_accuracy: 0.9219
Epoch 62/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0531 - accuracy: 0.9826 - val_loss: 0.0724 - val_accuracy: 0.9688
Epoch 63/100
54/54 [==============================] - 22s 414ms/step - loss: 0.0076 - accuracy: 0.9977 - val_loss: 0.0528 - val_accuracy: 0.9740
Epoch 64/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0222 - accuracy: 0.9913 - val_loss: 0.0450 - val_accuracy: 0.9688
Epoch 65/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0097 - accuracy: 0.9965 - val_loss: 0.1083 - val_accuracy: 0.9583
Epoch 66/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0162 - accuracy: 0.9931 - val_loss: 0.0127 - val_accuracy: 0.9896
Epoch 67/100
54/54 [==============================] - 22s 412ms/step - loss: 0.0106 - accuracy: 0.9959 - val_loss: 0.1542 - val_accuracy: 0.9583
Epoch 68/100
54/54 [==============================] - 22s 407ms/step - loss: 0.0048 - accuracy: 0.9977 - val_loss: 0.0142 - val_accuracy: 0.9948
Epoch 69/100
54/54 [==============================] - 22s 408ms/step - loss: 0.0285 - accuracy: 0.9896 - val_loss: 0.0096 - val_accuracy: 1.0000
Epoch 70/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0295 - accuracy: 0.9890 - val_loss: 0.0076 - val_accuracy: 1.0000
Epoch 71/100
54/54 [==============================] - 22s 407ms/step - loss: 0.0169 - accuracy: 0.9936 - val_loss: 0.0410 - val_accuracy: 0.9792
Epoch 72/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0075 - accuracy: 0.9983 - val_loss: 0.0171 - val_accuracy: 0.9896
Epoch 73/100
54/54 [==============================] - 22s 412ms/step - loss: 0.0208 - accuracy: 0.9936 - val_loss: 0.0236 - val_accuracy: 0.9844
Epoch 74/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0261 - accuracy: 0.9907 - val_loss: 0.0462 - val_accuracy: 0.9740
Epoch 75/100
54/54 [==============================] - 22s 408ms/step - loss: 0.0175 - accuracy: 0.9948 - val_loss: 0.3493 - val_accuracy: 0.9271
Epoch 76/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0454 - accuracy: 0.9844 - val_loss: 0.0297 - val_accuracy: 0.9844
Epoch 77/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0202 - accuracy: 0.9907 - val_loss: 0.0167 - val_accuracy: 0.9896
Epoch 78/100
54/54 [==============================] - 22s 407ms/step - loss: 0.0086 - accuracy: 0.9971 - val_loss: 0.0689 - val_accuracy: 0.9688
Epoch 79/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0138 - accuracy: 0.9954 - val_loss: 0.0069 - val_accuracy: 1.0000
Epoch 80/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0226 - accuracy: 0.9942 - val_loss: 0.0131 - val_accuracy: 0.9948
Epoch 81/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0418 - accuracy: 0.9815 - val_loss: 0.0745 - val_accuracy: 0.9688
Epoch 82/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0204 - accuracy: 0.9919 - val_loss: 0.1079 - val_accuracy: 0.9583
Epoch 83/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0073 - accuracy: 0.9983 - val_loss: 0.0804 - val_accuracy: 0.9792
Epoch 84/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0212 - accuracy: 0.9936 - val_loss: 0.0576 - val_accuracy: 0.9844
Epoch 85/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0278 - accuracy: 0.9902 - val_loss: 0.0225 - val_accuracy: 0.9896
Epoch 86/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0223 - accuracy: 0.9936 - val_loss: 0.0641 - val_accuracy: 0.9688
Epoch 87/100
54/54 [==============================] - 22s 412ms/step - loss: 0.0158 - accuracy: 0.9931 - val_loss: 0.1085 - val_accuracy: 0.9635
Epoch 88/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0186 - accuracy: 0.9942 - val_loss: 0.0695 - val_accuracy: 0.9792
Epoch 89/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0159 - accuracy: 0.9942 - val_loss: 0.0037 - val_accuracy: 1.0000
Epoch 90/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0392 - accuracy: 0.9861 - val_loss: 0.0226 - val_accuracy: 0.9948
Epoch 91/100
54/54 [==============================] - 22s 411ms/step - loss: 0.0030 - accuracy: 0.9994 - val_loss: 0.0444 - val_accuracy: 0.9844
Epoch 92/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0023 - accuracy: 1.0000 - val_loss: 0.0789 - val_accuracy: 0.9792
Epoch 93/100
54/54 [==============================] - 22s 408ms/step - loss: 0.0064 - accuracy: 0.9983 - val_loss: 0.0850 - val_accuracy: 0.9844
Epoch 94/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0291 - accuracy: 0.9884 - val_loss: 0.0295 - val_accuracy: 0.9896
Epoch 95/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0302 - accuracy: 0.9925 - val_loss: 0.0981 - val_accuracy: 0.9740
Epoch 96/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0205 - accuracy: 0.9925 - val_loss: 0.0667 - val_accuracy: 0.9844
Epoch 97/100
54/54 [==============================] - 22s 410ms/step - loss: 0.0335 - accuracy: 0.9844 - val_loss: 0.0784 - val_accuracy: 0.9792
Epoch 98/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0163 - accuracy: 0.9954 - val_loss: 0.0277 - val_accuracy: 0.9896
Epoch 99/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0258 - accuracy: 0.9896 - val_loss: 0.0599 - val_accuracy: 0.9844
Epoch 100/100
54/54 [==============================] - 22s 409ms/step - loss: 0.0227 - accuracy: 0.9919 - val_loss: 0.0191 - val_accuracy: 0.9896
In [57]:
scores = model.evaluate(test_ds)
8/8 [==============================] - 2s 108ms/step - loss: 0.0393 - accuracy: 0.9883
In [58]:
scores
Out[58]:
[0.03934174403548241, 0.98828125]

Plotting the Accuracy and Loss Curves¶

In [59]:
history
Out[59]:
<keras.src.callbacks.History at 0x2a0db8d00>
In [60]:
history.params
Out[60]:
{'verbose': 1, 'epochs': 100, 'steps': 54}
In [61]:
history.history.keys()
Out[61]:
dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])
In [62]:
history.history['loss'][:5] # show loss for first 5 epochs
Out[62]:
[0.8766953349113464,
 0.620872974395752,
 0.39409253001213074,
 0.299725204706192,
 0.2731475234031677]
In [63]:
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']
In [64]:
import numpy as np
# predict on the test dataset
y_pred = model.predict(test_ds)
y_pred_classes = np.argmax(y_pred, axis=1)

# extract the true labels
y_true = np.concatenate([y for x, y in test_ds], axis=0)
8/8 [==============================] - 1s 103ms/step
In [65]:
plt.figure(figsize=(16, 6))

# Accuracy plot
plt.subplot(1, 2, 1)
plt.plot(range(EPOCHS), acc, label='Training Accuracy')
plt.plot(range(EPOCHS), val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

# Loss plot
plt.subplot(1, 2, 2)
plt.plot(range(EPOCHS), loss, label='Training Loss')
plt.plot(range(EPOCHS), val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()
In [66]:
# Confusion matrix:
from sklearn.metrics import confusion_matrix

import seaborn as sns
conf_mat = confusion_matrix(y_true, y_pred_classes)
plt.figure(figsize=(10,7))
sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()
In [67]:
from sklearn.metrics import precision_score, recall_score, f1_score


# Assume y_pred is your prediction and y_true is the true labels
# and class_names is a list of your class names.

# If y_pred is not in the form of class labels, convert it to class labels
y_pred_classes = np.argmax(y_pred, axis=1)

precisions = precision_score(y_true, y_pred_classes, average=None)
recalls = recall_score(y_true, y_pred_classes, average=None)
f1_scores = f1_score(y_true, y_pred_classes, average=None)

x = np.arange(len(class_names))  
width = 0.3 

fig, ax = plt.subplots(figsize=(10,5))
rects1 = ax.bar(x - width, precisions, width, label='Precision')
rects2 = ax.bar(x, recalls, width, label='Recall')
rects3 = ax.bar(x + width, f1_scores, width, label='F1 Score')

ax.set_ylabel('Scores')
ax.set_title('Precision, Recall and F1 Score for each class')
ax.set_xticks(x)
ax.set_xticklabels(class_names)
ax.legend()

fig.tight_layout()

plt.show()

# Print precision, recall and f1-score values
print("Precision: ", precisions)
print("Recall: ", recalls)
print("F1-score: ", f1_scores)
Precision:  [0.52727273 0.58267717 0.26315789]
Recall:  [0.52727273 0.578125   0.27777778]
F1-score:  [0.52727273 0.58039216 0.27027027]
In [68]:
#ROC Curve and AUC: (applicable for binary classification)

from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from itertools import cycle

# Binarize the output
y_true_bin = label_binarize(y_true, classes=[0, 1, 2])
n_classes = y_true_bin.shape[1]

# Compute ROC curve and ROC area for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_pred[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Plot all ROC curves
plt.figure()
colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
for i, color in zip(range(n_classes), colors):
    plt.plot(fpr[i], tpr[i], color=color, label='ROC curve of class {0} (area = {1:0.2f})'
             ''.format(i, roc_auc[i]))

plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Some extension of Receiver operating characteristic to multi-class')
plt.legend(loc="lower right")
plt.show()

Run prediction on a sample image¶

In [69]:
import numpy as np
for images_batch, labels_batch in test_ds.take(1):
    
    first_image = images_batch[0].numpy().astype('uint8')
    first_label = labels_batch[0].numpy()
    
    print("first image to predict")
    plt.imshow(first_image)
    print("actual label:",class_names[first_label])
    
    batch_prediction = model.predict(images_batch)
    print("predicted label:",class_names[np.argmax(batch_prediction[0])])
first image to predict
actual label: Potato___Early_blight
1/1 [==============================] - 0s 152ms/step
predicted label: Potato___Early_blight

Write a function for inference¶

In [70]:
def predict(model, img):
    img_array = tf.keras.preprocessing.image.img_to_array(images[i].numpy())
    img_array = tf.expand_dims(img_array, 0)

    predictions = model.predict(img_array)

    predicted_class = class_names[np.argmax(predictions[0])]
    confidence = round(100 * (np.max(predictions[0])), 2)
    return predicted_class, confidence

Now run inference on few sample images

In [71]:
plt.figure(figsize=(12, 12))
for images, labels in test_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        
        predicted_class, confidence = predict(model, images[i].numpy())
        actual_class = class_names[labels[i]] 
        
        plt.title(f"Actual: {actual_class},\n Predicted: {predicted_class}.\n Confidence: {confidence}%")
        
        plt.axis("off")
1/1 [==============================] - 0s 47ms/step
1/1 [==============================] - 0s 15ms/step
1/1 [==============================] - 0s 15ms/step
1/1 [==============================] - 0s 14ms/step
1/1 [==============================] - 0s 14ms/step
1/1 [==============================] - 0s 15ms/step
1/1 [==============================] - 0s 14ms/step
1/1 [==============================] - 0s 13ms/step
1/1 [==============================] - 0s 15ms/step

Saving the Model¶

We append the model to the list of models as a new version

In [ ]:
 
In [72]:
model.save("../models/final23")
INFO:tensorflow:Assets written to: ../models/final23/assets
INFO:tensorflow:Assets written to: ../models/final23/assets
In [ ]: